{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "rILQ3zfUD688"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from models.crnn import model\n",
    "from models.ctc_loss import CTCLoss\n",
    "from models.accuracy import WordAccuracy\n",
    "from models.config import BATCH_SIZE, BUFFER_SIZE, WORK_PATH\n",
    "from models.data_prepare import load_and_preprocess_image, decode_label, get_image_path\n",
    "import numpy as np\n",
    "import json\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "当前可用GPU数量：  0\n"
     ]
    }
   ],
   "source": [
    "# List the physical GPU devices TensorFlow detects and report how many there are.\n",
    "gpu_devices = tf.config.experimental.list_physical_devices('GPU')\n",
    "print(\"当前可用GPU数量： \", len(gpu_devices))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 一、数据集准备"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1、获取并划分训练集、验证集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "KmIDr1WCExzk"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "986 986 14 14\n"
     ]
    }
   ],
   "source": [
    "# get_image_path returns four sequences, unpacked here as training paths/labels\n",
    "# and validation paths/labels; their lengths are printed as a quick sanity check.\n",
    "train_dir = WORK_PATH + 'dataset/train/'\n",
    "(train_all_image_paths, train_all_image_labels,\n",
    " val_all_image_paths, val_all_image_labels) = get_image_path(train_dir)\n",
    "print(len(train_all_image_paths), len(train_all_image_labels),\n",
    "      len(val_all_image_paths), len(val_all_image_labels))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "aajzP6u7MjEv"
   },
   "source": [
    "## 2、训练集数据预处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "RhBHiraiMe4r"
   },
   "outputs": [],
   "source": [
    "\n",
    "# Steps per epoch = number of whole batches that fit in the training set.\n",
    "train_images_num = len(train_all_image_paths)\n",
    "train_steps_per_epoch = train_images_num // BATCH_SIZE\n",
    "\n",
    "AUTOTUNE = tf.data.experimental.AUTOTUNE\n",
    "\n",
    "# Pipeline order: load image -> shuffle -> repeat -> batch -> decode labels -> prefetch.\n",
    "# NOTE(review): ignore_errors() is applied after batching, so an element that fails\n",
    "# upstream presumably costs the batch it belonged to - confirm this is intended.\n",
    "train_ds = (\n",
    "    tf.data.Dataset.from_tensor_slices((train_all_image_paths, train_all_image_labels))\n",
    "    .map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)\n",
    "    .shuffle(buffer_size=BUFFER_SIZE)\n",
    "    .repeat()\n",
    "    .batch(BATCH_SIZE)\n",
    "    .map(decode_label, num_parallel_calls=AUTOTUNE)\n",
    "    .apply(tf.data.experimental.ignore_errors())\n",
    "    .prefetch(AUTOTUNE)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "aajzP6u7MjEv"
   },
   "source": [
    "## 3、验证集数据预处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "val_images_num = len(val_all_image_paths)\n",
    "# Plain floor division gives 0 steps whenever the validation set is smaller than\n",
    "# one batch, which would hand validation_steps=0 to model.fit. Clamp to 1 so at\n",
    "# least one validation batch is always evaluated; because the dataset repeats, a\n",
    "# full batch is available as long as the validation set is non-empty.\n",
    "val_steps_per_epoch = max(1, val_images_num // BATCH_SIZE)\n",
    "\n",
    "AUTOTUNE = tf.data.experimental.AUTOTUNE\n",
    "\n",
    "# Same stages as the training pipeline:\n",
    "# load image -> shuffle -> repeat -> batch -> decode labels -> prefetch.\n",
    "val_ds = (\n",
    "    tf.data.Dataset.from_tensor_slices((val_all_image_paths, val_all_image_labels))\n",
    "    .map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)\n",
    "    .shuffle(buffer_size=BUFFER_SIZE)\n",
    "    .repeat()\n",
    "    .batch(BATCH_SIZE)\n",
    "    .map(decode_label, num_parallel_calls=AUTOTUNE)\n",
    "    .apply(tf.data.experimental.ignore_errors())\n",
    "    .prefetch(AUTOTUNE)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 二、模型训练"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1、模型结构"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "加载已保存模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the saved checkpoint; this rebinds `model`, replacing the object imported\n",
    "# from models.crnn. compile=False because the model is compiled in a later cell.\n",
    "checkpoint_path = WORK_PATH + 'output/crnn_6.h5'\n",
    "model = tf.keras.models.load_model(checkpoint_path, compile=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "conv2d (Conv2D)              multiple                  1792      \n",
      "_________________________________________________________________\n",
      "max_pooling2d (MaxPooling2D) multiple                  0         \n",
      "_________________________________________________________________\n",
      "conv2d_1 (Conv2D)            multiple                  73856     \n",
      "_________________________________________________________________\n",
      "max_pooling2d_1 (MaxPooling2 multiple                  0         \n",
      "_________________________________________________________________\n",
      "conv2d_2 (Conv2D)            multiple                  295168    \n",
      "_________________________________________________________________\n",
      "batch_normalization (BatchNo multiple                  32        \n",
      "_________________________________________________________________\n",
      "conv2d_3 (Conv2D)            multiple                  590080    \n",
      "_________________________________________________________________\n",
      "zero_padding2d (ZeroPadding2 multiple                  0         \n",
      "_________________________________________________________________\n",
      "max_pooling2d_2 (MaxPooling2 multiple                  0         \n",
      "_________________________________________________________________\n",
      "conv2d_4 (Conv2D)            multiple                  1180160   \n",
      "_________________________________________________________________\n",
      "batch_normalization_1 (Batch multiple                  16        \n",
      "_________________________________________________________________\n",
      "conv2d_5 (Conv2D)            multiple                  2359808   \n",
      "_________________________________________________________________\n",
      "zero_padding2d_1 (ZeroPaddin multiple                  0         \n",
      "_________________________________________________________________\n",
      "max_pooling2d_3 (MaxPooling2 multiple                  0         \n",
      "_________________________________________________________________\n",
      "conv2d_6 (Conv2D)            multiple                  1049088   \n",
      "_________________________________________________________________\n",
      "batch_normalization_2 (Batch multiple                  4         \n",
      "_________________________________________________________________\n",
      "reshape (Reshape)            multiple                  0         \n",
      "_________________________________________________________________\n",
      "bidirectional (Bidirectional multiple                  1574912   \n",
      "_________________________________________________________________\n",
      "bidirectional_1 (Bidirection multiple                  1574912   \n",
      "_________________________________________________________________\n",
      "dense (Dense)                multiple                  2837403   \n",
      "=================================================================\n",
      "Total params: 11,537,231\n",
      "Trainable params: 11,537,205\n",
      "Non-trainable params: 26\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Print the layer table and parameter counts of the loaded model.\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "gpBZnmh0O1uY"
   },
   "source": [
    "## 2、模型编译"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Bj3sFrXgO0us"
   },
   "outputs": [],
   "source": [
    "# Adam optimizer with a 0.001 learning rate; the loss and the metric are the\n",
    "# CTCLoss and WordAccuracy classes imported from the project's models package.\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)\n",
    "model.compile(\n",
    "    optimizer=optimizer,\n",
    "    loss=CTCLoss(),\n",
    "    metrics=[WordAccuracy()],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3、配置回调函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "0ImnMvZX7Ajn"
   },
   "outputs": [],
   "source": [
    "# Timestamped TensorBoard run directory. time.strftime is used instead of\n",
    "# time.asctime() because asctime() output contains spaces and colons, which are\n",
    "# awkward in paths and invalid in Windows directory names.\n",
    "run_log_dir = WORK_PATH + 'logs/{}'.format(time.strftime('%Y%m%d-%H%M%S'))\n",
    "\n",
    "callbacks = [\n",
    "    # Saves a checkpoint per epoch; Keras fills in {epoch} in the file name.\n",
    "    # NOTE(review): without save_best_only=True the monitor argument does not\n",
    "    # influence which checkpoints are written.\n",
    "    tf.keras.callbacks.ModelCheckpoint(WORK_PATH + 'output/crnn_{epoch}.h5',\n",
    "                                       monitor='val_loss', verbose=1),\n",
    "    tf.keras.callbacks.TensorBoard(log_dir=run_log_dir),\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "aajzP6u7MjEv"
   },
   "source": [
    "## 4、模型训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training resumes from the crnn_6.h5 checkpoint loaded earlier, so epoch\n",
    "# numbering must continue from 6. Starting at 0 would make ModelCheckpoint\n",
    "# overwrite crnn_1.h5 ... crnn_6.h5 (including the file that was just loaded)\n",
    "# and mislabel the epochs in the logs.\n",
    "RESUME_EPOCH = 6\n",
    "\n",
    "# `epochs` is the index of the final epoch, not the number of additional epochs,\n",
    "# so this run covers epochs RESUME_EPOCH+1 through 20.\n",
    "# The `workers` argument is omitted: Keras documents it as applying only to\n",
    "# generator / keras.utils.Sequence input, and train_ds is a tf.data.Dataset.\n",
    "model.fit(train_ds,\n",
    "          epochs=20,\n",
    "          steps_per_epoch=train_steps_per_epoch,\n",
    "          validation_data=val_ds,\n",
    "          validation_steps=val_steps_per_epoch,\n",
    "          initial_epoch=RESUME_EPOCH,\n",
    "          callbacks=callbacks)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "authorship_tag": "ABX9TyP2Eg6pvevGDGaZ1lOjlih4",
   "collapsed_sections": [],
   "mount_file_id": "1pHFhW9EuX06p63puOUBufLkQUJm6JayS",
   "name": "ocr_dataset.ipynb",
   "private_outputs": true,
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
